{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ⚠️ &nbsp; See [NYU-DLSP21](https://github.com/Atcold/NYU-DLSP21) for an updated version &nbsp; ⚠️"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Displaying routine\n",
    "\n",
    "def display_images(in_, out, n=1, label=None, count=False):\n",
    "    \"\"\"Plot rows of four 28x28 images with matplotlib.\n",
    "\n",
    "    For each of the `n` rows, an optional figure showing four images taken\n",
    "    from `in_` is drawn first, followed by a figure showing the four images\n",
    "    at the same positions in `out`.\n",
    "\n",
    "    Args:\n",
    "        in_: tensor viewable as (-1, 28, 28), or None to skip the input figures.\n",
    "        out: tensor viewable as (-1, 28, 28); must hold at least 4 * n images.\n",
    "        n: number of rows (figures) of four images to draw.\n",
    "        label: optional prefix for the title of the input figures.\n",
    "        count: if True, title every output image with its index in `out`.\n",
    "    \"\"\"\n",
    "    # Convert once up front instead of once per row\n",
    "    in_pic = None if in_ is None else in_.detach().cpu().view(-1, 28, 28)\n",
    "    out_pic = out.detach().cpu().view(-1, 28, 28)\n",
    "\n",
    "    # label defaults to None, which cannot be concatenated to a string\n",
    "    caption = 'real test data / reconstructions'\n",
    "    if label is not None:\n",
    "        caption = f'{label} – {caption}'\n",
    "\n",
    "    for N in range(n):\n",
    "        if in_pic is not None:\n",
    "            plt.figure(figsize=(18, 4))\n",
    "            plt.suptitle(caption, color='w', fontsize=16)\n",
    "            for i in range(4):\n",
    "                plt.subplot(1, 4, i + 1)\n",
    "                plt.imshow(in_pic[i + 4 * N])\n",
    "                plt.axis('off')\n",
    "        plt.figure(figsize=(18, 6))\n",
    "        for i in range(4):\n",
    "            plt.subplot(1, 4, i + 1)\n",
    "            plt.imshow(out_pic[i + 4 * N])\n",
    "            plt.axis('off')\n",
    "            if count:\n",
    "                plt.title(str(4 * N + i), color='w')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seeds\n",
    "\n",
    "torch.manual_seed(1)\n",
    "torch.cuda.manual_seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define data loading step\n",
    "\n",
    "batch_size = 256\n",
    "\n",
    "# Pinned host memory only speeds up host-to-GPU copies; asking for it on a\n",
    "# CPU-only machine just triggers a warning, so request it only with CUDA.\n",
    "kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}\n",
    "\n",
    "# Both loaders read MNIST from ./data as tensors and shuffle on every pass;\n",
    "# only the training dataset asks for a download.\n",
    "train_loader = DataLoader(\n",
    "    MNIST('./data', train=True, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=batch_size, shuffle=True, **kwargs)\n",
    "test_loader = DataLoader(\n",
    "    MNIST('./data', train=False, transform=transforms.ToTensor()),\n",
    "    batch_size=batch_size, shuffle=True, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining the device\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining the model\n",
    "\n",
    "d = 20\n",
    "\n",
    "class VAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(784, d ** 2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(d ** 2, d * 2)\n",
    "        )\n",
    "\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(d, d ** 2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(d ** 2, 784),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def reparameterise(self, mu, logvar):\n",
    "        if self.training:\n",
    "            std = logvar.mul(0.5).exp_()\n",
    "            eps = std.new_empty(std.size()).normal_()\n",
    "            return eps.mul_(std).add_(mu)\n",
    "        else:\n",
    "            return mu\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu_logvar = self.encoder(x.view(-1, 784)).view(-1, 2, d)\n",
    "        mu = mu_logvar[:, 0, :]\n",
    "        logvar = mu_logvar[:, 1, :]\n",
    "        z = self.reparameterise(mu, logvar)\n",
    "        return self.decoder(z), mu, logvar\n",
    "\n",
    "# Build the network, then move it to the selected device\n",
    "model = VAE()\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting the optimiser\n",
    "\n",
    "learning_rate = 1e-3\n",
    "\n",
    "# Adam over every parameter of the model\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction + β * KL divergence losses summed over all elements and batch\n",
    "\n",
    "def loss_function(x_hat, x, mu, logvar, β=1):\n",
    "    BCE = nn.functional.binary_cross_entropy(\n",
    "        x_hat, x.view(-1, 784), reduction='sum'\n",
    "    )\n",
    "    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))\n",
    "\n",
    "    return BCE + β * KLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training and testing the VAE\n",
    "\n",
    "epochs = 10\n",
    "codes = {'μ': [], 'logσ2': [], 'y': []}\n",
    "for epoch in range(epochs + 1):\n",
    "    # ---- Training (skipped at epoch 0, so the untrained net is tested first) ----\n",
    "    if epoch > 0:\n",
    "        model.train()\n",
    "        train_loss = 0\n",
    "        for x, _ in train_loader:\n",
    "            x = x.to(device)\n",
    "            # forward pass\n",
    "            x_hat, mu, logvar = model(x)\n",
    "            loss = loss_function(x_hat, x, mu, logvar)\n",
    "            # backward pass and parameter update\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # accumulate the summed loss of this batch\n",
    "            train_loss += loss.item()\n",
    "        average_loss = train_loss / len(train_loader.dataset)\n",
    "        print(f'====> Epoch: {epoch} Average loss: {average_loss:.4f}')\n",
    "\n",
    "    # ---- Testing ----\n",
    "    means, logvars, labels = [], [], []\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x = x.to(device)\n",
    "            x_hat, mu, logvar = model(x)\n",
    "            test_loss += loss_function(x_hat, x, mu, logvar).item()\n",
    "            # keep the codes and labels of every test batch\n",
    "            means.append(mu.detach())\n",
    "            logvars.append(logvar.detach())\n",
    "            labels.append(y.detach())\n",
    "\n",
    "    # One concatenated tensor per epoch under each key of codes\n",
    "    for key, chunks in (('μ', means), ('logσ2', logvars), ('y', labels)):\n",
    "        codes[key].append(torch.cat(chunks))\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print(f'====> Test set loss: {test_loss:.4f}')\n",
    "    # Show the first four inputs and reconstructions of the last test batch\n",
    "    display_images(x, x_hat, 1, f'Epoch {epoch}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating a few samples\n",
    "\n",
    "N = 16\n",
    "# Draw N standard-normal codes, then move them to the model's device\n",
    "z = torch.randn(N, d)\n",
    "z = z.to(device)\n",
    "sample = model.decoder(z)\n",
    "display_images(None, sample, n=N // 4, count=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display last test batch\n",
    "\n",
    "# Four rows of four images taken from x\n",
    "display_images(None, x, n=4, count=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose starting and ending point for the interpolation -> shows original and reconstructed\n",
    "\n",
    "A, B = 1, 14\n",
    "# Decode the codes stored in mu at positions A and B\n",
    "endpoints = torch.stack((mu[A].data, mu[B].data), 0)\n",
    "sample = model.decoder(endpoints)\n",
    "# One row of four images: x[A], x[B], decoded A, decoded B\n",
    "originals = [x[A].data.view(-1), x[B].data.view(-1)]\n",
    "reconstructions = [sample.data[0], sample.data[1]]\n",
    "display_images(None, torch.stack(originals + reconstructions, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform an interpolation between input A and B, in N steps\n",
    "\n",
    "N = 16\n",
    "# Interpolation weights going from 0 (code A only) to 1 (code B only), one per row.\n",
    "# Built on the device of mu, so nothing is tied to a hard-coded code size.\n",
    "t = torch.linspace(0, 1, N, device=mu.device).unsqueeze(1)\n",
    "# Broadcasting (N, 1) weights against a single code gives one blended code per row\n",
    "code = (1 - t) * mu[A].detach() + t * mu[B].detach()\n",
    "sample = model.decoder(code)\n",
    "display_images(None, sample, N // 4, count=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "from res.plot_lib import set_default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): set_default is imported from the project-local res.plot_lib;\n",
    "# presumably it applies default plot styling with this figure size — confirm there\n",
    "set_default(figsize=(15, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, Y, E = list(), list(), list()  # input, classes, embeddings\n",
    "N = 1000  # samples per epoch\n",
    "epochs = (0, 5, 10)  # positions in codes, which holds one entry per tested epoch\n",
    "for epoch in epochs:\n",
    "    X.append(codes['μ'][epoch][:N])\n",
    "    # t-SNE is stochastic: a fixed random_state keeps the plot identical across runs.\n",
    "    # scikit-learn works on numpy arrays, so the conversion is made explicit.\n",
    "    tsne = TSNE(n_components=2, random_state=1)\n",
    "    E.append(tsne.fit_transform(X[-1].detach().cpu().numpy()))\n",
    "    Y.append(codes['y'][epoch][:N])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(ncols=3)\n",
    "for col, epoch_id in enumerate(epochs):\n",
    "    ax, embedding = axes[col], E[col]\n",
    "    # One point per sample, coloured by the matching entry of Y\n",
    "    points = ax.scatter(embedding[:, 0], embedding[:, 1], c=Y[col], cmap='tab10')\n",
    "    ax.grid(False)\n",
    "    ax.set_title(f'Epoch {epoch_id}')\n",
    "    ax.axis('equal')\n",
    "# Single colour bar shared by all panels, with one tick per integer from 0 to 9\n",
    "class_ticks = np.arange(10)\n",
    "fig.colorbar(points, ax=axes[:], ticks=class_ticks, boundaries=np.arange(11) - .5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
